# encoding: utf-8

import torch


class MyAbsLossFn(torch.nn.Module):
    """Count the predictions that disagree with the given labels.

    The prediction is the index of the largest value of ``out`` along
    dimension 1. The result is returned as a NumPy value, so it carries no
    autograd history and cannot be back-propagated through.
    """

    def __init__(self):
        super(MyAbsLossFn, self).__init__()

    def forward(self, out, label):
        """Return the number of positions where the prediction != label.

        Args:
            out: Tensor of scores; ``argmax`` is taken over dimension 1.
                NOTE(review): presumably shaped (batch, classes, ...) --
                confirm against the caller.
            label: Expected indices, compared element-wise with the
                prediction. A tensor with the same number of elements as the
                prediction but a different layout (for example ``(N, 1)``
                against ``(N,)``) is reshaped to the prediction's shape.

        Returns:
            A zero-dimensional ``numpy.ndarray`` holding the mismatch count.
        """
        predicted = out.argmax(1)

        # Without this, a (N, 1) label compared against a (N,) prediction
        # broadcasts to an (N, N) grid and the count is silently inflated.
        # Labels that already match, or that differ in element count
        # (genuine broadcasting), are left untouched.
        if (
            torch.is_tensor(label)
            and label.shape != predicted.shape
            and label.numel() == predicted.numel()
        ):
            label = label.reshape(predicted.shape)

        mismatches = (predicted != label).sum()
        # detach() is the supported replacement for the legacy ``.data``.
        return mismatches.detach().cpu().numpy()
